#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Runner script for the LLaDA GUI application.
"""

import os
import sys
import torch
import traceback

# Anchor the process to the folder that contains this script, so that the
# relative paths used further down (e.g. ./venv) resolve against it no matter
# where the script was launched from.
current_dir = os.path.split(os.path.abspath(__file__))[0]
os.chdir(current_dir)

# Make sure we have the correct Python environment: when a ./venv interpreter
# exists and is not the one currently running, replace this process with it.
if os.path.exists('./venv/bin/python'):
    python_path = os.path.abspath('./venv/bin/python')
    if sys.executable != python_path:
        print(f"Restarting with the virtual environment Python: {python_path}")
        # exec replaces the process image without flushing Python-level
        # buffers, so flush first or the message above can be lost when
        # stdout is piped or redirected.
        sys.stdout.flush()
        # The working directory was already changed to the script's folder,
        # so a relative sys.argv[0] would no longer point at this file.
        # Pass the script by absolute path and forward the remaining
        # command-line arguments untouched.
        script_path = os.path.join(current_dir, os.path.basename(__file__))
        os.execl(python_path, python_path, script_path, *sys.argv[1:])

# Add the script's directory to the import path. Reuse current_dir, which was
# resolved before the chdir earlier in this file: on Python versions before
# 3.9 __file__ may be a relative path, and re-resolving it after the working
# directory has changed would produce the wrong folder.
sys.path.insert(0, current_dir)

# Check if psutil is installed (needed for memory monitor); if it is not,
# attempt a one-time install with pip and retry the import.
try:
    import psutil
except ImportError:
    import importlib
    import subprocess

    print("Installing required package: psutil...")
    try:
        # Use an argument list instead of a shell string so an interpreter
        # path containing spaces or shell metacharacters is passed intact.
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "psutil"],
            check=False,
        )
    except OSError as install_error:
        # Best effort only: the retry import below decides success/failure.
        print(f"Could not run pip: {install_error}")
    # The import system caches directory contents; refresh them so a package
    # installed while this process is running can be found.
    importlib.invalidate_caches()
    try:
        import psutil
    except ImportError:
        print("Failed to install psutil. Please run: pip install psutil")
        sys.exit(1)

# Check if required packages are installed.
# Each entry is a pip requirement: the part before "==" is imported to test
# availability, and the optional version after "==" is the release the install
# hint asks for.
# NOTE: torch is also imported unconditionally at the top of this file, so a
# missing torch raises ImportError before this check is ever reached.
import importlib

required_packages = [
    'PyQt6',
    'torch',
    'transformers==4.38.2',
    'numpy',
]

missing_packages = []
for requirement in required_packages:
    module_name, _, pinned_version = requirement.partition('==')
    try:
        importlib.import_module(module_name)
    except ImportError:
        missing_packages.append(requirement)
        continue

    if not pinned_version:
        continue

    # Being importable says nothing about the version; compare against the pin
    # and warn (without aborting) when they differ.
    try:
        from importlib.metadata import version as _distribution_version
        installed_version = _distribution_version(module_name)
    except ImportError:
        # importlib.metadata is unavailable (Python < 3.8), or no distribution
        # metadata exists under this name (PackageNotFoundError is a subclass
        # of ImportError). Either way the version cannot be checked.
        continue
    if installed_version != pinned_version:
        print(
            f"Warning: {module_name} {installed_version} is installed, "
            f"but version {pinned_version} is expected."
        )

if missing_packages:
    print(f"Missing required packages: {', '.join(missing_packages)}")
    print("Please install them using:")
    print(f"  {sys.executable} -m pip install {' '.join(missing_packages)}")
    sys.exit(1)

# When CUDA is usable, ask torch to empty its CUDA cache and list the visible
# devices. Any failure here is reported as a warning and does not stop startup.
if torch.cuda.is_available():
    try:
        torch.cuda.empty_cache()
        device_total = torch.cuda.device_count()
        print(f"CUDA is available. Found {device_total} device(s).")
        for device_index in range(device_total):
            device_name = torch.cuda.get_device_name(device_index)
            print(f"  Device {device_index}: {device_name}")
    except Exception as cuda_error:
        print(f"Warning: Error accessing CUDA: {cuda_error}")

# Import and run the main application. Any ordinary exception raised while
# importing or running it is reported to the user and the process exits with
# status 1.
try:
    # Import directly from the llada_gui.py file
    sys.path.insert(0, current_dir)  # Ensure current dir is in path
    from llada_gui import main
    main()
except Exception as e:
    # Display error in a GUI dialog if possible, otherwise print to console
    error_msg = f"Error starting LLaDA GUI: {str(e)}\n\n{traceback.format_exc()}"
    try:
        from PyQt6.QtWidgets import QApplication, QMessageBox
        # main() may already have created the application object before it
        # failed; Qt supports only one per process, so reuse it when present
        # instead of constructing a second one.
        app = QApplication.instance() or QApplication([])
        QMessageBox.critical(None, "LLaDA GUI Error", error_msg)
    except Exception:
        # Fall back to the console for ordinary failures only, so that
        # KeyboardInterrupt and SystemExit are not swallowed here.
        print(error_msg)
    sys.exit(1)
